# -*- encoding: utf-8 -*-

"""
@File:      basetest.py
@Time:      2022/05/12 16:18:47
@Author:    bolong.tbl
@Version:   1.0
@Contact:   bolong.tbl@alibaba-inc.com
@License:   Mulan PSL v2
"""

import time
from enum import Enum
from avocado import Test
from common.hosts import LocalHost, RemoteHost
from common.package import PackageManager
from avocado.core import exceptions

class BaseTest(Test):
    """
    Common base class for every testcase in this suite.

    On top of avocado's ``Test`` it only logs the setup/teardown phases
    and offers a ``skip`` helper that raises ``TestSkipError``.
    """

    def setUp(self):
        """Log that the setup phase has started."""
        self.log.info('setup')

    def tearDown(self):
        """Log that the teardown phase has started."""
        self.log.info('teardown')

    def skip(self, message=None):
        """
        Raise avocado's ``TestSkipError`` carrying *message*.

        :param message: optional reason attached to the exception
        """
        skip_error = exceptions.TestSkipError(message)
        raise skip_error

class LocalTest(BaseTest):
    """
    LocalTest class represents a testcase running on localhost

    Tips: if you pass remote host parameters to the testcase,
          you can also run the testcase on remote host

    :param remote (str): optional remote host ip or hostname; when set,
                         cmd() sends commands to that host instead
    :param username (str): remote host username
    :param port (int): remote host port (default 22)
    :param key (str): remote host key (default None)
    :param password (str): remote host password
    """

    # Class-level fallback. setUp replaces it per instance with the value
    # returned by PackageManager.setup_rpm_install() when a package is given,
    # and tearDown hands it back to PackageManager.rpm_uninstall().
    RPM_FLAG = {}

    def setUp(self, param_dic=None):
        """
        Create the local host handle (plus a remote one when the 'remote'
        parameter is set), then optionally install an rpm package.

        :param param_dic: optional dict; its "pkg_name" entry names the
                          package to install through PackageManager
        """
        super().setUp()
        # None instead of a shared mutable {} default.
        if param_dic is None:
            param_dic = {}
        self.local = LocalHost()
        self.remote = None
        if self.params.get('remote'):
            self.remote = RemoteHost(host=self.params.get('remote'),
                                     username=self.params.get('username'),
                                     port=self.params.get('port', default=22),
                                     key=self.params.get('key', default=None),
                                     password=self.params.get('password'))
        self.pkg_name = param_dic.get("pkg_name")
        if self.pkg_name:
            self.RPM_FLAG = PackageManager.setup_rpm_install(self, self.pkg_name)

    def cmd(self, command, host="local", ignore_status=False):
        """
        Run *command* and return the selected host's cmd() result.

        The command is sent to the remote host when one was configured in
        setUp, otherwise to localhost.

        :param host: NOTE(review): currently ignored. Routing depends only on
                     whether a remote host exists; presumably kept to mirror
                     RemoteTest.cmd()'s signature.
        :param ignore_status: forwarded to the host's cmd()
        """
        target = self.remote if self.remote else self.local
        return target.cmd(command, ignore_status=ignore_status)

    def tearDown(self, param_dic=None):
        """
        Uninstall the package installed by setUp, if any.

        :param param_dic: unused here
        """
        super().tearDown()
        # tearDown may run even when setUp raised before pkg_name was
        # assigned (e.g. while building RemoteHost); avoid masking that
        # original error with an AttributeError.
        if getattr(self, 'pkg_name', None):
            PackageManager.rpm_uninstall(self, self.RPM_FLAG)


class RemoteTest(BaseTest):
    """
    RemoteTest class represents a testcase running on remote host

    :param remote (str): remote host ip or hostname
    :param username (str): remote host username
    :param port (int): remote host port
    :param key (str): remote host key
    :param password (str): remote host password
    """

    # Class-level fallback. setUp replaces it per instance with the value
    # returned by PackageManager.setup_rpm_install() when a package is given,
    # and tearDown hands it back to PackageManager.rpm_uninstall().
    RPM_FLAG = {}

    def setUp(self, param_dic=None):
        """
        Connect the remote and local host handles, probe the remote image,
        then optionally install an rpm package on the remote host.

        :param param_dic: optional dict; its "pkg_name" entry names the
                          package to install through PackageManager
        """
        super().setUp()
        # None instead of a shared mutable {} default.
        if param_dic is None:
            param_dic = {}
        self.remote = RemoteHost(host=self.params.get('remote'),
                                 username=self.params.get('username'),
                                 port=self.params.get('port', default=22),
                                 key=self.params.get('key', default=None),
                                 password=self.params.get('password'))
        self.local = LocalHost()
        self.image = self._get_image_info()
        self.pkg_name = param_dic.get("pkg_name")
        if self.pkg_name:
            self.RPM_FLAG = PackageManager.setup_rpm_install(self, self.pkg_name, "remote")

    def cmd(self, command, host="remote", ignore_status=False):
        """
        Run *command* on the chosen host and return that host's cmd() result.

        :param host: "remote" (default) or "local"
        :param ignore_status: forwarded to the host's cmd()
        :raises ValueError: if *host* is neither "remote" nor "local"
        """
        if host == 'remote':
            return self.remote.cmd(command, ignore_status=ignore_status)
        if host == 'local':
            return self.local.cmd(command, ignore_status=ignore_status)
        # Fail loudly instead of returning None, which would only surface
        # later as a confusing tuple-unpacking error in the caller.
        raise ValueError(
            "unknown host {!r}: expected 'remote' or 'local'".format(host))

    def wait_ssh_connect(self, timeout=600):
        """
        Retry ``self.remote.session.connect()`` roughly once per second until
        it returns a truthy value or *timeout* seconds have elapsed.

        :param timeout: wall-clock budget in seconds (time spent inside
                        connect() counts against it)
        :return: True once a connection succeeds, False on timeout
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.remote.session.connect():
                self.log.info('ssh connect successfully')
                return True
            time.sleep(1)
        self.log.warning('ssh connect timed out after %s seconds', timeout)
        return False

    @staticmethod
    def _parse_value(text):
        """
        Return the value of the first ``KEY=VALUE`` line in *text*.

        Surrounding whitespace and quote characters (single or double) are
        stripped from the value; everything after the first '=' is kept.
        Returns None when *text* is empty or its first line has no '='.
        """
        lines = (text or '').strip().splitlines()
        if not lines:
            return None
        _, sep, value = lines[0].partition('=')
        if not sep:
            return None
        return value.strip().strip('"\'')

    def _get_image_info(self):
        """
        Probe the remote host and return an Image describing it.

        The host is classified as ECS when ``find /etc -name "image-id"``
        prints anything, otherwise as PHY. OS type and version come from the
        ID= and VERSION_ID= lines of /etc/os-release; arch and kernel are
        the raw outputs of ``arch`` and ``uname -r``.
        """
        image = Image()
        _, output = self.cmd('find /etc -name "image-id"')
        if output:
            image.env = Env.ECS
            _, output = self.cmd('cat /etc/image-id | grep image_id')
            image.id = self._parse_value(output)
        else:
            image.env = Env.PHY
        _, output = self.cmd('cat /etc/os-release')
        for line in output.splitlines():
            line = line.strip()
            if line.startswith('ID='):
                os_id = self._parse_value(line)
                try:
                    image.ostype = OSType[os_id.upper()]
                except KeyError:
                    # Unrecognised distro: keep ostype as None rather than
                    # aborting setUp with a bare KeyError.
                    self.log.warning(
                        'unsupported os ID {!r}; image.ostype left as None'.format(os_id))
            elif line.startswith('VERSION_ID='):
                image.version = self._parse_value(line)
        _, image.arch = self.cmd('arch')
        _, image.kernel = self.cmd('uname -r')
        self.log.debug('image.env: {}, image.ostype: {}, image.version: {}, image.arch: {}, image.kernel: {}, image.id: {}'.format(
            image.env, image.ostype, image.version, image.arch, image.kernel, image.id))
        return image

    def tearDown(self, param_dic=None):
        """
        Uninstall the package installed on the remote host by setUp, if any.

        :param param_dic: unused here
        """
        super().tearDown()
        # tearDown may run even when setUp raised before pkg_name was
        # assigned (e.g. inside _get_image_info); avoid masking that
        # original error with an AttributeError.
        if getattr(self, 'pkg_name', None):
            PackageManager.rpm_uninstall(self, self.RPM_FLAG, "remote")

class OSType(Enum):
    """OS families matched (upper-cased) against the ID= field of /etc/os-release."""

    ALINUX = 'alinux'
    ANOLIS = 'anolis'
    ALIOS = 'alios'


class Env(Enum):
    """Kind of machine under test: an ECS instance or a physical host."""

    ECS = 'ecs'
    PHY = 'phy'


class Image(object):
    """
    Plain record describing the machine under test.

    Every attribute starts as None; RemoteTest._get_image_info() fills in
    whatever it can detect.
    """

    def __init__(self):
        self.env = None      # Env member (ECS or PHY)
        self.id = None       # image_id from /etc/image-id (ECS hosts)
        self.arch = None     # output of `arch`
        self.kernel = None   # output of `uname -r`
        self.ostype = None   # OSType member from os-release ID=
        self.version = None  # os-release VERSION_ID=

    def __repr__(self):
        """Show every field so Image values are readable in logs and asserts."""
        return ('{}(env={!r}, id={!r}, arch={!r}, kernel={!r}, '
                'ostype={!r}, version={!r})').format(
                    type(self).__name__, self.env, self.id, self.arch,
                    self.kernel, self.ostype, self.version)
